In [64]:
import numpy as np
import pandas as pd
import json
import sys
import os
import matplotlib
matplotlib.use('Agg') 
import matplotlib.pyplot as plt
import seaborn as sns
import pdb

from util import utils as data_utils

%pylab inline
%matplotlib inline
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'Blues'

# for auto-reloading external modules
# see http://stackoverflow.com/questions/1907993/autoreload-of-modules-in-ipython
%load_ext autoreload
%autoreload 2

json_file = './cifar_results/noise_40/bootstrap_var_lr_001/checkpoint_50.json'
FDIR = os.path.dirname(json_file)
NUM_CLASSIFY = 10
Populating the interactive namespace from numpy and matplotlib
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [ ]:
 
In [65]:
# Plot gradients norms for the entire learning process
grads_json_filename = os.path.join(FDIR, 'model_grads.json')
grads = [[], [], []]
grads_key = ['max_grad_w1_16', 'max_grad_w1_32', 'max_grad_w1_64']
if os.path.exists(grads_json_filename):
    with open(grads_json_filename, 'r') as fp:
        data = json.load(fp)
        for i, k in enumerate(grads_key):
            if data[0].get(k, None) is None:
                continue
            for batch_grads in data:
                grads[i].append(batch_grads[k])

def plot_grads(grads, title, x_label, y_label, figsize=(10, 8)):
    plt.figure(figsize=figsize)
    # plt.subplot(2, 1, 1)
    plt.plot(grads)
    plt.title(title)
    plt.ylabel(y_label)
    plt.xlabel(x_label)
    
for i, g in enumerate(grads):
    if len(g) > 0:
        plot_grads(g, grads_key[i], 'iterations', grads_key[i])
        # pass
In [66]:
with open(json_file, 'r') as fp:
    data = json.load(fp)
# Loss history might not be of equal length.
train_loss_hist = data['train_loss_history']
val_loss_hist = data['val_loss_history']

# pdb.set_trace()
def plot_loss_hist(loss_hist, title,):
    plt.figure(figsize=(5,4))
    plt.subplot(1, 1, 1)
    plt.plot(loss_hist)
    plt.title(title)  # Train Loss
    plt.ylabel('loss')
    plt.xlabel('time')
    plt.show()
    
plot_loss_hist(train_loss_hist, 'Train Loss')
plot_loss_hist(val_loss_hist, 'Val loss')

if data.get('crit1_loss_history', None) is not None:
    plot_loss_hist(data['crit1_loss_history'], 'Target criterion loss')

if data.get('crit2_loss_history', None) is not None and \
    len(data['crit2_loss_history']) > 0:
    plot_loss_hist(data['crit2_loss_history'], 'Pred criterion loss')

if data.get('pred_loss_history', None) is not None and \
    len(data['pred_loss_history']) > 0:
    plot_loss_hist(data['pred_loss_history'], 'Total Pred loss (beta*t + (1-beta)*p)')    

if data.get('beta_loss_history', None) is not None and \
    len(data['beta_loss_history']) > 0:
    plot_loss_hist(data['beta_loss_history'], 'Beta loss')
In [67]:
if data.get('KL_loss_history', None) is not None:
    # Loss history might not be of equal length.
    KL_loss_hist = data['KL_loss_history']

    plt.figure(figsize=(10,8))
    plt.plot(KL_loss_hist)
    plt.title('KL loss')
    plt.ylabel('loss')
    plt.xlabel('time')
    plt.show()
In [68]:
def get_conf(json_file, num_classes=26, json_key='conf'):
    with open(json_file, 'r') as fp:
        data = json.load(fp)
        conf = data.get(json_key, None)
    if conf is None:
        return
    # c1 = conf.split('\n')[1].split("]")[0].split("[ ")[1].split(" ")
    c1 = conf.split('\n')
    # print(c1)
    conf_mat, row_idx = np.zeros((num_classes, num_classes)), 0
    for i in c1:
        #pdb.set_trace()
        is_conf_row = False
        if ']' in i and '[[' in i:
            val = i.split(']')[0].split('[[')[1].split(' ')
            is_conf_row = True
        elif ']' in i and '[' in i:
            val = i.split(']')[0].split('[')[1].split(' ')
            is_conf_row = True
        if is_conf_row:
            col_idx = 0
            for v in val:
                if not len(v):
                    continue
                try:
                    conf_mat[row_idx, col_idx] = int(v)
                    col_idx = col_idx + 1
                except:
                    continue
            row_idx = row_idx + 1
    
    assert(row_idx == num_classes)
    conf_mat = conf_mat.astype(int)
    fdir = os.path.dirname(json_file)
    json_name = os.path.basename(json_file)[:-5]
    conf_file_name = fdir + '/' + 'conf_' + json_name + '.txt'
    np.savetxt(conf_file_name, conf_mat, fmt='%d', delimiter=', ')
    return conf_mat


def plot_conf(norm_conf):
  # Plot using seaborn
  # (this is style I used for ResNet matrix)
  plt.figure(figsize=(10,6))
  df_cm = pd.DataFrame(norm_conf)
  sns.heatmap(df_cm, annot=True, cmap="Blues")
  plt.show()
In [69]:
def get_sorted_checkpoints(fdir):
    # Checkpoint files are named as 'checkpoint_%d.json'
    checkpoint_map = {}
    for f in os.listdir(fdir):
        if f.endswith('json') and f.startswith('checkpoint'):
            checkpoint_num = int(f.split('checkpoint_')[-1].split('.')[0])
            checkpoint_map[checkpoint_num] = f
    sorted_checkpoints = []
    for k in sorted(checkpoint_map.keys()):
        v = checkpoint_map[k]
        sorted_checkpoints.append(v)
    return sorted_checkpoints
In [70]:
def best_f_scores(fdir, num_classes=5): 
    best_checkpoints = [None, None, None]
    best_3_fscores = [0, 0, 0]
    best_confs = [np.array(()), np.array(()), np.array(())]
    f1_weight_list = [1.0] * num_classes
    f1_weights = np.array(f1_weight_list)
    sorted_checkpoint_files = get_sorted_checkpoints(fdir)
    for f in sorted_checkpoint_files:
        json_file = fdir + '/' + f
        conf = get_conf(json_file, num_classes, json_key='val_conf')
        norm_conf = data_utils.normalize_conf(conf)
        f1 = data_utils.get_f1_score(conf, f1_weights)
        kappa = data_utils.computeKappa(conf)
        wt_f1 = data_utils.computeWeightedF1(conf)
        print('file: {}, f1: {:.3f}, kappa: {:.3f}, weighted-F1: {:.3f}'.format(
                f, f1, kappa, wt_f1))
        plot_conf(norm_conf)
        max_idx = -1
        for i in range(len(best_3_fscores)):
            if best_3_fscores[i] > f1:
                break
            max_idx = i
        for j in range(max_idx):
            best_3_fscores[j] = best_3_fscores[j+1]
            best_confs[j] = best_confs[j+1]
            best_checkpoints[j] = best_checkpoints[j+1]

        best_3_fscores[max_idx] = f1
        best_confs[max_idx] = conf
        best_checkpoints[max_idx] = f

    return best_3_fscores, best_confs, best_checkpoints
In [71]:
def plot_train_conf(fdir, num_classes=5):
    sorted_checkpoint_files = get_sorted_checkpoints(fdir)
    if len(sorted_checkpoint_files) > 0:
        last_checkpoint = sorted_checkpoint_files[-1]
        json_file = fdir + '/' + last_checkpoint
        conf = get_conf(json_file, num_classes=num_classes, json_key='train_conf')
        print(conf)
        norm_conf = data_utils.normalize_conf(conf)
        f1_weight_list = [1.0] * num_classes
        f1_weights = np.array(f1_weight_list)
        f1 = data_utils.get_f1_score(conf, f1_weights)
        kappa = data_utils.computeKappa(conf)
        wt_f1 = data_utils.computeWeightedF1(conf)
        print('file: {}, f1: {:.3f}, kappa: {:.3f}, weighted-F1: {:.3f}'.format(
            f, f1, kappa, wt_f1))
        plot_conf(norm_conf)

plot_train_conf(FDIR, num_classes=NUM_CLASSIFY)
[[3203   27   96   51   40    8 1447   14   60   77]
 [  36 3452    1    9    0    6    4   11   20 1420]
 [ 133    4 2930  239  116 1363  100   60    8   12]
 [  60   14   80 2819   67  398   85   52 1443   26]
 [  21    4   78   86 3222   81   91 1394    7    4]
 [  48    3 1398  429  146 2827   63  108    5    8]
 [  37    6   88   89 1395   47 3265   19    3    0]
 [  15 1463   24   57  133   78    4 3205    3   30]
 [  81   27   36 1225   28  144   39   24 3317   35]
 [1450  173   28   23    9    7    5   19   55 3300]]
file: <built-in method f of mtrand.RandomState object at 0x7f616712f5f0>, f1: 0.631, kappa: 0.414, weighted-F1: 0.631
In [72]:
best_f_scores(FDIR, num_classes=NUM_CLASSIFY)
file: checkpoint_1.json, f1: 0.018, kappa: -0.001, weighted-F1: 0.018
file: checkpoint_2.json, f1: 0.123, kappa: -0.034, weighted-F1: 0.123
file: checkpoint_3.json, f1: 0.206, kappa: 0.018, weighted-F1: 0.206
file: checkpoint_4.json, f1: 0.314, kappa: 0.173, weighted-F1: 0.314
file: checkpoint_5.json, f1: 0.417, kappa: 0.234, weighted-F1: 0.417
file: checkpoint_6.json, f1: 0.557, kappa: 0.558, weighted-F1: 0.557
file: checkpoint_7.json, f1: 0.669, kappa: 0.668, weighted-F1: 0.669
file: checkpoint_8.json, f1: 0.702, kappa: 0.714, weighted-F1: 0.702
file: checkpoint_9.json, f1: 0.695, kappa: 0.734, weighted-F1: 0.695
file: checkpoint_10.json, f1: 0.733, kappa: 0.739, weighted-F1: 0.733
file: checkpoint_11.json, f1: 0.746, kappa: 0.708, weighted-F1: 0.746
file: checkpoint_12.json, f1: 0.791, kappa: 0.787, weighted-F1: 0.791
file: checkpoint_13.json, f1: 0.783, kappa: 0.787, weighted-F1: 0.783
file: checkpoint_14.json, f1: 0.806, kappa: 0.800, weighted-F1: 0.806
file: checkpoint_15.json, f1: 0.800, kappa: 0.803, weighted-F1: 0.800
file: checkpoint_16.json, f1: 0.813, kappa: 0.819, weighted-F1: 0.813
file: checkpoint_17.json, f1: 0.805, kappa: 0.809, weighted-F1: 0.805
file: checkpoint_18.json, f1: 0.805, kappa: 0.813, weighted-F1: 0.805
file: checkpoint_19.json, f1: 0.829, kappa: 0.828, weighted-F1: 0.829
file: checkpoint_20.json, f1: 0.836, kappa: 0.824, weighted-F1: 0.836
file: checkpoint_21.json, f1: 0.844, kappa: 0.835, weighted-F1: 0.844
file: checkpoint_22.json, f1: 0.844, kappa: 0.834, weighted-F1: 0.844
file: checkpoint_23.json, f1: 0.845, kappa: 0.836, weighted-F1: 0.845
file: checkpoint_24.json, f1: 0.846, kappa: 0.839, weighted-F1: 0.846
file: checkpoint_25.json, f1: 0.846, kappa: 0.839, weighted-F1: 0.846
file: checkpoint_26.json, f1: 0.845, kappa: 0.839, weighted-F1: 0.845
file: checkpoint_27.json, f1: 0.847, kappa: 0.836, weighted-F1: 0.847
file: checkpoint_28.json, f1: 0.845, kappa: 0.839, weighted-F1: 0.845
file: checkpoint_29.json, f1: 0.848, kappa: 0.842, weighted-F1: 0.848
file: checkpoint_30.json, f1: 0.847, kappa: 0.842, weighted-F1: 0.847
file: checkpoint_31.json, f1: 0.848, kappa: 0.844, weighted-F1: 0.848
file: checkpoint_32.json, f1: 0.850, kappa: 0.843, weighted-F1: 0.850
file: checkpoint_33.json, f1: 0.851, kappa: 0.843, weighted-F1: 0.851
file: checkpoint_34.json, f1: 0.846, kappa: 0.841, weighted-F1: 0.846
file: checkpoint_35.json, f1: 0.846, kappa: 0.840, weighted-F1: 0.846
file: checkpoint_36.json, f1: 0.846, kappa: 0.842, weighted-F1: 0.846
file: checkpoint_37.json, f1: 0.849, kappa: 0.842, weighted-F1: 0.849
file: checkpoint_38.json, f1: 0.850, kappa: 0.844, weighted-F1: 0.850
file: checkpoint_39.json, f1: 0.848, kappa: 0.842, weighted-F1: 0.848
file: checkpoint_40.json, f1: 0.848, kappa: 0.846, weighted-F1: 0.848
file: checkpoint_41.json, f1: 0.851, kappa: 0.849, weighted-F1: 0.851
file: checkpoint_42.json, f1: 0.849, kappa: 0.843, weighted-F1: 0.849
file: checkpoint_43.json, f1: 0.850, kappa: 0.845, weighted-F1: 0.850
file: checkpoint_44.json, f1: 0.849, kappa: 0.845, weighted-F1: 0.849
file: checkpoint_45.json, f1: 0.852, kappa: 0.851, weighted-F1: 0.852
file: checkpoint_46.json, f1: 0.852, kappa: 0.850, weighted-F1: 0.852
file: checkpoint_47.json, f1: 0.850, kappa: 0.850, weighted-F1: 0.850
file: checkpoint_48.json, f1: 0.847, kappa: 0.845, weighted-F1: 0.847
file: checkpoint_49.json, f1: 0.851, kappa: 0.846, weighted-F1: 0.851
file: checkpoint_50.json, f1: 0.850, kappa: 0.847, weighted-F1: 0.850
file: checkpoint_51.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_52.json, f1: 0.851, kappa: 0.848, weighted-F1: 0.851
file: checkpoint_53.json, f1: 0.852, kappa: 0.851, weighted-F1: 0.852
file: checkpoint_54.json, f1: 0.851, kappa: 0.849, weighted-F1: 0.851
file: checkpoint_55.json, f1: 0.850, kappa: 0.849, weighted-F1: 0.850
file: checkpoint_56.json, f1: 0.853, kappa: 0.851, weighted-F1: 0.853
file: checkpoint_57.json, f1: 0.852, kappa: 0.852, weighted-F1: 0.852
file: checkpoint_58.json, f1: 0.851, kappa: 0.851, weighted-F1: 0.851
file: checkpoint_59.json, f1: 0.853, kappa: 0.850, weighted-F1: 0.853
file: checkpoint_60.json, f1: 0.853, kappa: 0.855, weighted-F1: 0.853
file: checkpoint_61.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_62.json, f1: 0.852, kappa: 0.854, weighted-F1: 0.852
file: checkpoint_63.json, f1: 0.855, kappa: 0.854, weighted-F1: 0.855
file: checkpoint_64.json, f1: 0.852, kappa: 0.852, weighted-F1: 0.852
file: checkpoint_65.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_66.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_67.json, f1: 0.853, kappa: 0.856, weighted-F1: 0.853
file: checkpoint_68.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_69.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_70.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_71.json, f1: 0.854, kappa: 0.854, weighted-F1: 0.854
file: checkpoint_72.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_73.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_74.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_75.json, f1: 0.855, kappa: 0.853, weighted-F1: 0.855
file: checkpoint_76.json, f1: 0.854, kappa: 0.851, weighted-F1: 0.854
file: checkpoint_77.json, f1: 0.856, kappa: 0.854, weighted-F1: 0.856
file: checkpoint_78.json, f1: 0.852, kappa: 0.852, weighted-F1: 0.852
file: checkpoint_79.json, f1: 0.852, kappa: 0.852, weighted-F1: 0.852
file: checkpoint_80.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_81.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_82.json, f1: 0.854, kappa: 0.857, weighted-F1: 0.854
file: checkpoint_83.json, f1: 0.855, kappa: 0.853, weighted-F1: 0.855
file: checkpoint_84.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_85.json, f1: 0.855, kappa: 0.853, weighted-F1: 0.855
file: checkpoint_86.json, f1: 0.854, kappa: 0.852, weighted-F1: 0.854
file: checkpoint_87.json, f1: 0.855, kappa: 0.854, weighted-F1: 0.855
file: checkpoint_88.json, f1: 0.856, kappa: 0.856, weighted-F1: 0.856
file: checkpoint_89.json, f1: 0.853, kappa: 0.854, weighted-F1: 0.853
file: checkpoint_90.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_91.json, f1: 0.855, kappa: 0.852, weighted-F1: 0.855
file: checkpoint_92.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_93.json, f1: 0.854, kappa: 0.854, weighted-F1: 0.854
file: checkpoint_94.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_95.json, f1: 0.854, kappa: 0.854, weighted-F1: 0.854
file: checkpoint_96.json, f1: 0.855, kappa: 0.855, weighted-F1: 0.855
file: checkpoint_97.json, f1: 0.853, kappa: 0.852, weighted-F1: 0.853
file: checkpoint_98.json, f1: 0.854, kappa: 0.853, weighted-F1: 0.854
file: checkpoint_99.json, f1: 0.853, kappa: 0.853, weighted-F1: 0.853
file: checkpoint_100.json, f1: 0.854, kappa: 0.855, weighted-F1: 0.854
file: checkpoint_101.json, f1: 0.854, kappa: 0.855, weighted-F1: 0.854
Out[72]:
([0.85522590806395893, 0.85549053209529924, 0.85398794020671587],
 [array([[864,   7,  24,   6,   8,   3,   5,   6,  51,  26],
         [  5, 953,   3,   4,   1,   1,   1,   3,  13,  16],
         [ 43,   0, 801,  31,  42,  32,  40,   6,   5,   0],
         [ 13,   2,  45, 753,  31,  87,  42,  16,   9,   2],
         [  4,   1,  29,  28, 874,  14,  33,  16,   1,   0],
         [  6,   1,  70, 149,  33, 703,  13,  25,   0,   0],
         [  6,   1,  22,  23,   6,   5, 935,   2,   0,   0],
         [  9,   1,  17,  17,  67,  28,   2, 856,   0,   3],
         [ 14,   7,   7,  17,   6,   1,   1,   0, 938,   9],
         [ 13,  68,   4,  10,   3,   1,   3,   1,  17, 880]]),
  array([[860,   6,  29,   5,  12,   2,   6,   9,  46,  25],
         [  4, 948,   3,   4,   1,   1,   1,   5,  13,  20],
         [ 36,   0, 807,  26,  45,  32,  42,   7,   5,   0],
         [ 12,   2,  48, 738,  37,  93,  43,  19,   6,   2],
         [  2,   1,  28,  23, 879,  14,  36,  16,   1,   0],
         [  6,   1,  68, 138,  36, 712,  13,  26,   0,   0],
         [  6,   0,  23,  18,   7,   5, 939,   2,   0,   0],
         [  6,   1,  17,  18,  68,  24,   2, 863,   0,   1],
         [ 16,   8,   7,  17,   8,   1,   2,   0, 929,  12],
         [ 14,  63,   4,   9,   3,   1,   3,   2,  16, 885]]),
  array([[851,   5,  33,   5,  14,   3,   7,   8,  48,  26],
         [  5, 948,   3,   5,   1,   2,   1,   3,  13,  19],
         [ 34,   0, 812,  28,  44,  31,  41,   6,   4,   0],
         [ 10,   1,  50, 734,  39,  93,  45,  17,   9,   2],
         [  1,   1,  28,  21, 883,  13,  37,  15,   1,   0],
         [  4,   1,  75, 136,  38, 708,  13,  25,   0,   0],
         [  6,   0,  24,  20,   8,   5, 935,   2,   0,   0],
         [  6,   1,  17,  17,  75,  29,   2, 852,   0,   1],
         [ 12,   7,   7,  18,   8,   1,   3,   0, 936,   8],
         [ 14,  62,   3,   9,   4,   1,   3,   3,  16, 885]])],
 ['checkpoint_85.json', 'checkpoint_87.json', 'checkpoint_101.json'])